{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test1 data type is  (28, 28) 2    2\n",
      "0  1\n",
      "original data type is  (10000, 28, 28) (10000,)\n",
      "(10000, 10) (10000, 784)\n",
      "INFO:tensorflow:Restoring parameters from ./model\\mininst.ckpt\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAANxUlEQVR4nO3de4xU93nG8ecBc7EwtqFgSjGygwOycSpDsiJx3YstN6nDH8GRckOJgyNHpGrcJhJSYrmV4igXWVVst1WjVCRGIZUvcn2JqWIlJsSR6wRhLy4BbJJAXOpgVmDEpuBWhd312z/2UG3wzpll5sycMe/3I41m5rxzznk18OyZmd+c+TkiBODsN6nuBgB0B2EHkiDsQBKEHUiCsANJnNPNnU31tJiuGd3cJZDK/+q/dTJOeLxaW2G3fYOkv5c0WdK3IuLOssdP1wy909e3s0sAJbbFloa1ll/G254s6euS3itpqaTVtpe2uj0AndXOe/YVkvZFxEsRcVLSg5JWVdMWgKq1E/YFkn495v6BYtlvsb3Wdr/t/iGdaGN3ANrRTtjH+xDgDd+9jYj1EdEXEX1TNK2N3QFoRzthPyBp4Zj7F0s62F47ADqlnbA/J2mx7bfYnirpI5I2VdMWgKq1PPQWEcO2b5X0A40OvW2IiBcq6wxApdoaZ4+IJyQ9UVEvADqIr8sCSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5BEV39KGq3Z/+WrS+sj0xtPzjn3yldL19161SMt9XTKZT/6RGl95rPnNqzN+4eftrVvnBmO7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOPsPWDwe4tL67uX/WPH9j3UeIh+Qn5+3bdK6/f1zW9Ye2jzn5SuO7Jnb0s9YXwc2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcbZu6DZOPpPlj3YsX3/028Wldbv3vru0vqll5SfD//k0kdL6x+dOdCw9pWb55Suu+jzjLNXqa2w294v6bikEUnDEdFXRVMAqlfFkf26iDhSwXYAdBDv2YEk2g17SHrS9nbba8d7gO21tvtt9w/pRJu7A9Cqdl/GXxMRB21fJGmz7Z9HxNNjHxAR6yWtl6TzPbvN0y4AtKqtI3tEHCyuD0t6TNKKKpoCUL2Ww257hu2Zp25Leo+k3VU1BqBa7byMnyfpMduntnN/RHy/kq7eZIavf0dp/UdXfb3JFqaUVv9ucElp/akPl4x4Hjxcuu6Swf7S+qTp00vrX932+6X12+fsalgbnjVcui6q1XLYI+IlSVdV2AuADmLoDUiCsANJEHYgCcIOJEHYgSQ4xbUCry2YWlqf1ORvarOhtR+/r3x4a+SlX5TW27Hvi8tL6/fPvqvJFqY1rFz8fY413cSzDSRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMM5egQu/s7W0/oH+j5XWPXistD48sP8MO6rOJ1f+sLR+3qTG4+joLRzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtm7YOTFX9bdQkP7v3J1af2WC7/WZAvlPzW9buBdDWszf7indN2RJnvGmeHIDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMM5+lvvNTeXj6D/5ePk4+gWTysfRt56YXFrf8eXGvzt/7rFnS9dFtZoe2W1vsH3Y9u4xy2bb3mx7b3E9q7NtAmjXRF7Gf1vSDactu03SlohYLGlLcR9AD2sa9oh4WtLR0xavkrSxuL1R0o0V9wWgYq1+QDcvIgYkqbi+qNEDba+13W+7f0gnWtwdgHZ1/NP4iFgfEX0R0TelZJI/AJ3VatgP2Z4vScX14epaAtAJrYZ9k6Q1xe01kh6vph0AndJ0nN32A5KulTTH9gFJX5B0p6SHbN8i6WVJH+xkk2jdkbdHab3ZOHoza378ydL6ku8ylt4rmoY9IlY3
KF1fcS8AOoivywJJEHYgCcIOJEHYgSQIO5AEp7ieBU5uvqRhbevldzVZu3zo7aqta0rrV6z7VWmdn4PuHRzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtnfBM5ZdGlp/Utv/ZeGtVlNTmHd3uSXwi75UvlI+cjgYPkG0DM4sgNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoyzvwlc9tArpfXlU1v/m716y5+X1pf87LmWt43ewpEdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JgnL0HDK65urT+xXnNfvt9WsPKmv1/WrrmFZ/bV1rnd9/PHk2P7LY32D5se/eYZXfYfsX2juKysrNtAmjXRF7Gf1vSDeMsvycilhWXJ6ptC0DVmoY9Ip6WdLQLvQDooHY+oLvV9s7iZf6sRg+yvdZ2v+3+ITX5wTMAHdNq2L8h6TJJyyQNSGr4CVJErI+Ivojom1LyQRKAzmop7BFxKCJGIuJ1Sd+UtKLatgBUraWw254/5u77Je1u9FgAvaHpOLvtByRdK2mO7QOSviDpWtvLJIWk/ZI+1cEe3/TOWfB7pfU/+qttpfXzJrX+9mfri28trS8Z5Hz1LJqGPSJWj7P43g70AqCD+LoskARhB5Ig7EAShB1IgrADSXCKaxfsuX1haf27v/uvbW3/ul0fbFjjFFacwpEdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JgnL0Ltr/vniaPaO8XfC74i9cb1oYHB9vaNs4eHNmBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnG2c8CQ/MuaFibcnJBFzt5o5FXjzSsxYny6cA8rfz7B5PnzmmpJ0kamXthaX3vuqktb3siYsQNa5f/ZZPfIDh2rKV9cmQHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQYZz8LfO/hDXW30NAf/Pt4kwCPOnLo/NJ1Z809Xlrf9o77W+qp1y39m1tL64s+t7Wl7TY9stteaPsp23tsv2D7M8Xy2bY3295bXM9qqQMAXTGRl/HDktZFxBWS3iXp07aXSrpN0paIWCxpS3EfQI9qGvaIGIiI54vbxyXtkbRA0ipJG4uHbZR0Y6eaBNC+M/qAzvalkpZL2iZpXkQMSKN/ECRd1GCdtbb7bfcPqfy70AA6Z8Jht32epEckfTYiJvxN/IhYHxF9EdE3pc0fVgTQugmF3fYUjQb9voh4tFh8yPb8oj5f0uHOtAigCk2H3mxb0r2S9kTE3WNKmyStkXRncf14Rzo8C6x68aOl9S1ve7hLnXTfT5c/UNu+/ydONqwNReOf356IlTtvLq3/147WT79d8Mxwy+uWmcg4+zWSbpK0y/aOYtntGg35Q7ZvkfSypMaThAOoXdOwR8QzkhqdaX99te0A6BS+LgskQdiBJAg7kARhB5Ig7EASnOLaBef+2X+U1q/8avkpjdHBf6WZlx8trXfyNNIr/+0TpfV4eUZb21/08GuNi8/uamvbs7S3rXodOLIDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKOiK7t7HzPjneaE+WATtkWW3Qsjo57lipHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiiadhtL7T9lO09tl+w/Zli+R22X7G9o7is7Hy7AFo1kekHhiWti4jnbc+UtN325qJ2T0R8rXPtAajKROZnH5A0UNw+bnuPpAWdbgxAtc7oPbvtSyUtl7StWHSr7Z22N9ie1WCdtbb7bfcP6URbzQJo3YTDbvs8SY9I+mxEHJP0DUmXSVqm0SP/XeOtFxHrI6IvIvqmaFoFLQNoxYTCbnuKRoN+X0Q8KkkRcSgiRiLidUnflLSic20CaNdEPo23pHsl7YmIu8csnz/mYe+XtLv69gBUZSKfxl8j6SZJu2zvKJbdLmm17WWSQtJ+SZ/qSIcAKjGRT+OfkTTe71A/UX07ADqFb9ABSRB2IAnCDiRB2IEk
CDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeScER0b2f2q5L+c8yiOZKOdK2BM9OrvfVqXxK9tarK3i6JiLnjFboa9jfs3O6PiL7aGijRq731al8SvbWqW73xMh5IgrADSdQd9vU1779Mr/bWq31J9NaqrvRW63t2AN1T95EdQJcQdiCJWsJu+wbbv7C9z/ZtdfTQiO39tncV01D319zLBtuHbe8es2y27c229xbX486xV1NvPTGNd8k047U+d3VPf9719+y2J0v6paR3Szog6TlJqyPixa420oDt/ZL6IqL2L2DY/mNJr0n6TkS8rVj2t5KORsSdxR/KWRHx+R7p7Q5Jr9U9jXcxW9H8sdOMS7pR0s2q8bkr6etD6sLzVseRfYWkfRHxUkSclPSgpFU19NHzIuJpSUdPW7xK0sbi9kaN/mfpuga99YSIGIiI54vbxyWdmma81ueupK+uqCPsCyT9esz9A+qt+d5D0pO2t9teW3cz45gXEQPS6H8eSRfV3M/pmk7j3U2nTTPeM89dK9Oft6uOsI83lVQvjf9dExFvl/ReSZ8uXq5iYiY0jXe3jDPNeE9odfrzdtUR9gOSFo65f7GkgzX0Ma6IOFhcH5b0mHpvKupDp2bQLa4P19zP/+ulabzHm2ZcPfDc1Tn9eR1hf07SYttvsT1V0kckbaqhjzewPaP44ES2Z0h6j3pvKupNktYUt9dIerzGXn5Lr0zj3WiacdX83NU+/XlEdP0iaaVGP5H/laS/rqOHBn0tkvSz4vJC3b1JekCjL+uGNPqK6BZJvyNpi6S9xfXsHurtnyXtkrRTo8GaX1Nvf6jRt4Y7Je0oLivrfu5K+urK88bXZYEk+AYdkARhB5Ig7EAShB1IgrADSRB2IAnCDiTxfy43Cn7d/BIFAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import tensorflow.compat.v1 as tf\n",
    "tf.disable_v2_behavior()\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy\n",
    "import tensorflow.keras as keras\n",
    "import pandas as pd\n",
    "import argparse \n",
    "\n",
    "# MNIST is returned as numpy arrays; only the test split is used in this cell.\n",
    "(x,y), (x_test,y_test) = keras.datasets.mnist.load_data()\n",
    "\n",
    "# Pick one test sample (index 1) to inspect its image and its label.\n",
    "test_x = x_test[1,:,:]\n",
    "test_y = y_test[1]\n",
    "print(\"test1 data type is \",test_x.shape, test_y,pd.get_dummies(test_y))\n",
    "\n",
    "plt.imshow(test_x)\n",
    "print(\"original data type is \",x_test.shape, y_test.shape)    # y_test is a vector of labels\n",
    "\n",
    "# Flatten every image into one row. The sizes are derived from the array\n",
    "# itself instead of hard-coding 10000 x 784.\n",
    "x_test = x_test.reshape(x_test.shape[0], -1)\n",
    "# Turn the label vector into one indicator column per class.\n",
    "y_test = pd.get_dummies(y_test)\n",
    "\n",
    "print(y_test.shape,x_test.shape)\n",
    "\n",
    "# Start from an empty default graph so that re-running this cell does not\n",
    "# keep adding another imported copy of the model to the same graph.\n",
    "tf.reset_default_graph()\n",
    "saver = tf.train.import_meta_graph('./model/mininst.ckpt.meta')\n",
    "\n",
    "# Look up the parameter nodes of the restored model so the network can be\n",
    "# rebuilt by hand.\n",
    "with tf.Session() as sess:\n",
    "    saver.restore(sess, tf.train.latest_checkpoint(\"./model\"))\n",
    "    # NOTE(review): get_operation_by_name returns graph Operation handles, not\n",
    "    # parameter values. If the numeric arrays are needed, fetch op.outputs[0]\n",
    "    # with sess.run inside this session -- confirm the intended use.\n",
    "    weights = {\n",
    "        'h1': sess.graph.get_operation_by_name(\"h1\"),\n",
    "        'h2': sess.graph.get_operation_by_name(\"h2\"),\n",
    "        'out': sess.graph.get_operation_by_name(\"outh\")\n",
    "    }\n",
    "    biases = {\n",
    "        'b1': sess.graph.get_operation_by_name(\"b1\"),\n",
    "        'b2': sess.graph.get_operation_by_name(\"b2\"),\n",
    "        'out': sess.graph.get_operation_by_name(\"outb\")\n",
    "    }\n",
    "    # Deliberately no tf.global_variables_initializer() here: running it after\n",
    "    # saver.restore() would overwrite the values just loaded from the checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10000, 784)\n",
      "Tensor(\"Const_22:0\", shape=(10000, 784), dtype=uint8)\n",
      "[7 2 1 ... 4 5 6] [7 2 1 ... 4 5 6]\n",
      "0.8817\n"
     ]
    }
   ],
   "source": [
    "# 使用freeze脚本固化后的模型\n",
    "\n",
    "def load_graph(frozen_graph_filename):\n",
    "    \"\"\"Read a serialized GraphDef file and import it into a new tf.Graph.\n",
    "\n",
    "    Args:\n",
    "        frozen_graph_filename: path of the binary GraphDef (.pb) file to read.\n",
    "\n",
    "    Returns:\n",
    "        A newly created tf.Graph containing the imported nodes. The import is\n",
    "        done with name=\"prefix\", so nodes are addressed as \"prefix/<name>\".\n",
    "    \"\"\"\n",
    "    # Parse the file contents into a GraphDef message.\n",
    "    parsed_def = tf.GraphDef()\n",
    "    with tf.gfile.GFile(frozen_graph_filename, \"rb\") as pb_file:\n",
    "        parsed_def.ParseFromString(pb_file.read())\n",
    "\n",
    "    # Import into a dedicated graph, made the default only for this import.\n",
    "    imported_graph = tf.Graph()\n",
    "    with imported_graph.as_default():\n",
    "        tf.import_graph_def(\n",
    "            parsed_def,\n",
    "            input_map=None,\n",
    "            return_elements=None,\n",
    "            name=\"prefix\",\n",
    "            op_dict=None,\n",
    "            producer_op_list=None,\n",
    "        )\n",
    "    return imported_graph\n",
    "\n",
    "# Load the graph whose parameters have already been frozen.\n",
    "graph= load_graph('frozen_model.pb')\n",
    "\n",
    "# To see what the graph contains, list its operations (input and output\n",
    "# nodes are operations too): op.name is the operation's name and\n",
    "# op.values() is the list of tensors it produces.\n",
    "# for op in graph.get_operations():\n",
    "#     print(op.name,op.values())\n",
    "\n",
    "# Feeding and fetching need tensor names, not operation names:\n",
    "# 'prefix/X' names an operation, 'prefix/X:0' names its first output tensor.\n",
    "# NOTE: this rebinds x and y, which the first cell used for the MNIST\n",
    "# training arrays.\n",
    "x = graph.get_tensor_by_name('prefix/X:0')\n",
    "y = graph.get_tensor_by_name('prefix/output:0')\n",
    "print(x_test.shape)\n",
    "\n",
    "# x_test is fed directly as a numpy array. No tf.convert_to_tensor(x_test) is\n",
    "# created: it was never used for inference and would add a constant holding\n",
    "# the whole test set to the default graph every time this cell runs.\n",
    "with tf.Session(graph=graph) as sess:\n",
    "    y_out = sess.run(y, feed_dict={\n",
    "        x: x_test\n",
    "    })\n",
    "\n",
    "# Predicted class: arg-max over the model outputs of each row.\n",
    "label_test = numpy.argmax(y_out, 1)\n",
    "# True class: arg-max over the indicator columns of y_test.\n",
    "label = numpy.argmax(y_test.values, 1)\n",
    "\n",
    "print(label_test,label)\n",
    "correct_prediction = numpy.equal(label_test,label)\n",
    "# Fraction of rows where prediction and label agree.\n",
    "print(numpy.mean(correct_prediction))\n",
    "\n",
    "# TODO: inspect the error rate\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
